import json
import os
from random import randrange
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback, TrainerState, TrainerControl
from datasets import load_dataset, DatasetDict, Dataset
import pickle
from functools import partial
from tqdm import tqdm
from trl import SFTConfig, SFTTrainer
from nltk.translate.bleu_score import sentence_bleu
from rouge_score import rouge_scorer
import numpy as np
from peft import LoraConfig
import re
import argparse
from transformers import DataCollatorForSeq2Seq, DataCollatorWithPadding
import warnings
import warnings
from accelerate import Accelerator
from accelerate.utils import gather_object
from codebleu import calc_codebleu
import os
import torch.distributed as dist
from datetime import timedelta
import time
from utils_for_llm import *
from train_bf16 import predict_on_validation_BATCH

# Silence every Python warning for the rest of the run.
warnings.filterwarnings(action="ignore")

# Process-wide environment switches for NCCL and the tokenizers library.
# They are set here, before this script creates its process group and
# tokenizer further down.
os.environ.update(
    {
        'TORCH_NCCL_BLOCKING_WAIT': '1',
        'TORCH_NCCL_ASYNC_ERROR_HANDLING': '1',
        'TOKENIZERS_PARALLELISM': "False",
    }
)

# Create the NCCL process group explicitly (6-hour timeout) unless the script
# is hosted by PyCharm or a process group has already been initialized.
if os.environ.get('PYCHARM_HOSTED') != '1' and not dist.is_initialized():
    dist.init_process_group(backend='nccl', timeout=timedelta(hours=6))

# Single Accelerator instance (bf16 mixed precision) used by the rest of the
# script for process bookkeeping and data sharding.
accelerator = Accelerator(mixed_precision='bf16')



# Command-line interface. --task is validated at parse time so that a typo is
# rejected immediately with a usage message instead of failing later in the
# script.
parser = argparse.ArgumentParser(
    description="Run batched inference on data/test_data.json and dump the results to JSON."
)
parser.add_argument(
    "--task",
    default="code_generation",
    type=str,
    choices=("code_generation", "task_breakdown"),
    help="Which task to run; selects the batch size and related settings.",
)
parser.add_argument(
    "--load_path",
    default="/workflowllm/model_related/Llama3.1-8B-workflow-code_generation-neft5.0/best",
    type=str,
    help="Checkpoint directory to load the model from; pass an empty string to load the base model.",
)
parser.add_argument(
    "--model_version",
    default=3.1,
    type=float,
    help="Llama version used to build the base model (tokenizer) path.",
)
parser.add_argument(
    "--model_size",
    default=8,
    type=float,
    help="Model size (the number before 'B') used to build the base model (tokenizer) path.",
)
args = parser.parse_args()
task = args.task  # "code_generation" or "task_breakdown"


# Load two pickled objects from ../data (relative to the working directory).
# NOTE: unpickling can execute arbitrary code -- only point these paths at
# trusted files.
with open('../data/statistics.pkl', mode='rb') as stats_file:
    stat = pickle.load(stats_file)

with open('../data/identifier2python.pkl', mode='rb') as mapping_file:
    identifier2python = pickle.load(mapping_file)

# Per-task settings. 8192 is the default sequence length; the task_breakdown
# branch overrides it with a shorter limit and a larger batch.
# The format_instruction_* callables are not defined in this file
# (presumably they come from the utils_for_llm star import -- confirm).
max_seq_length = 8192
if task == "task_breakdown":
    format_instruction, target_col = format_instruction_without_code, "description"
    max_seq_length, BATCH_SIZE = 768, 8
elif task == "code_generation":
    format_instruction, target_col = format_instruction_with_code, "code"
    BATCH_SIZE = 2
else:
    raise Exception(f'{task} is not defined.')

# Inference uses the same per-process batch size as the task setting above.
eval_batch_size = BATCH_SIZE
# =================================================


# Entry point: load the test set, run sharded batched inference on every
# process, gather the results and dump them to JSON on the main process.
if __name__ == "__main__":
    # NOTE(review): this path is relative to the working directory and does
    # not use the '../data/' prefix of the pickle files loaded earlier in
    # this script -- confirm the intended launch directory.
    with open('data/test_data.json', 'r', encoding='utf-8') as fp:
        data = json.load(fp)

    # argparse stores the version and size as floats. The ':g' format drops a
    # trailing '.0', so a size of 8.0 renders as "8B" (not "8.0B") and a
    # version of 3.1 stays "3.1" in the directory name.
    model_id = (
        "/Pretrained_Language_Models/"
        f"Meta-Llama-{args.model_version:g}-{args.model_size:g}B-Instruct"
    )
    if accelerator.is_main_process:
        print('eval_batch_size: ', eval_batch_size)
        print('model_id:', model_id)
        print('load_path:', args.load_path)

    # The tokenizer is always read from the base model directory, even when
    # the weights come from a separate checkpoint. The EOS token doubles as
    # the padding token, and padding is applied on the left.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # An empty --load_path selects the base model; anything else is treated
    # as the checkpoint directory to load the weights from.
    model_path = args.load_path if args.load_path != "" else model_id
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        # NOTE(review): use_cache=False is kept from the original; confirm it
        # is intended for an inference-only run.
        use_cache=False,
        attn_implementation="flash_attention_2",
        # Map the whole model to the device indexed by this process.
        # NOTE(review): process_index is the global rank; on multi-node runs
        # local_process_index may be what is needed -- confirm.
        device_map={"": accelerator.process_index},
    )
    model.config.pretraining_tp = 1

    # Line all processes up before starting the wall-clock timer.
    accelerator.wait_for_everyone()
    start = time.time()

    # Each process runs inference on its own share of the test data.
    with accelerator.split_between_processes(data) as eval_dataset:
        infer_result = predict_on_validation_BATCH(
            model,
            tokenizer,
            eval_dataset,
            batch_size=eval_batch_size,
            external_data=True,
        )

    # Collect the per-process results; the elapsed time includes this gather.
    infer_result = gather_object(infer_result)
    timediff = time.time() - start

    minutes, seconds = divmod(timediff, 60)
    hours, minutes = divmod(minutes, 60)

    # Only the main process reports the timing and writes the output file.
    if accelerator.is_main_process:
        dump_path = f'synthesized-{args.task}_result.json'
        print(f"Time difference: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds")
        with open(dump_path, 'w', encoding='utf-8') as fp:
            json.dump(infer_result, fp, indent=4)
        print('dump_path:', dump_path)